import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from AlexNet1.AlexNetModel import *

def main():
    """Load Fashion-MNIST, add a channel axis, and train AlexNet on it.

    The train/test images are reshaped to ``(-1, 28, 28, 1)`` and passed,
    together with their labels, to ``AlexNet_train``. Once training returns,
    the per-sample height, width and channel count of the training images
    are printed.
    """
    fashion = keras.datasets.fashion_mnist
    train_split, test_split = fashion.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    # Append an explicit single-channel axis so each sample is 28x28x1.
    sample_shape = (28, 28, 1)
    train_images = train_images.reshape((-1,) + sample_shape)
    test_images = test_images.reshape((-1,) + sample_shape)

    AlexNet_train(train_images, train_labels, test_images, test_labels)

    # Report the per-sample dimensions (the leading batch axis is dropped).
    _, height, width, channels = train_images.shape
    print(height, width, channels)

if '__main__' == __name__:
    # Start training only when run as a script, not when imported.
    main()
